import os
import numpy as np
import argparse
import matplotlib.pyplot as plt
import rosbag
from utils import get_file_in_dir

def collect_timestamps(bag_files, topics):
    """Collect header timestamps for the given topics from several rosbags.

    Args:
        bag_files: Iterable of rosbag file paths to read.
        topics: Sequence of topic names to extract.

    Returns:
        Dict mapping each requested topic (in the order given) to the list of
        ``msg.header.stamp`` values, in seconds, gathered from all bags in
        turn. Topics without any messages map to an empty list.

    Every message on the requested topics must carry a ``header`` field;
    a header-less message raises ``AttributeError``.
    """
    timestamps = {topic: [] for topic in topics}
    for bag_file in bag_files:
        # The context manager closes the bag even if reading a message fails.
        with rosbag.Bag(bag_file) as bag:
            # Single pass over all requested topics instead of one pass per topic.
            for topic, msg, _ in bag.read_messages(topics=list(topics)):
                timestamps.setdefault(topic, []).append(msg.header.stamp.to_sec())
    return timestamps


def time_window(timestamps, window=2.0):
    """Return x-axis limits that start at the earliest collected timestamp.

    Args:
        timestamps: Dict mapping topic name to a list of timestamps (seconds).
        window: Length of the returned interval in seconds.

    Returns:
        ``(start, start + window)``, where ``start`` is the smallest timestamp
        over all topics. Returns ``None`` when no timestamps were collected.
    """
    starts = [min(stamps) for stamps in timestamps.values() if stamps]
    if not starts:
        return None
    start = min(starts)
    return start, start + window


def plot_timestamps(timestamps, xlim):
    """Draw one row of '|' markers per topic and show the figure.

    Args:
        timestamps: Dict mapping topic name to a list of timestamps (seconds).
            Rows are stacked bottom-up in dict order.
        xlim: ``(start, end)`` x-axis limits in seconds.
    """
    num_topics = len(timestamps)
    plt.figure(figsize=(20, num_topics * 0.5 + 1.0))
    # Plot the timestamps: topic i is drawn on the horizontal line y = i.
    for level, (topic, stamps) in enumerate(timestamps.items()):
        plt.scatter(stamps, np.full(len(stamps), float(level)), label=topic, marker='|', s=1000)

    plt.xlim(*xlim)
    plt.ylim(-1.0, num_topics + 1.0)

    plt.xlabel('Timestamp (seconds)')
    plt.title('Timestamps of multiple topics from multiple rosbags')
    # 'bottom right' is not a valid matplotlib legend location; 'lower right' is.
    plt.legend(loc='lower right', ncol=max(1, num_topics))
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="visualize timestamps of multiple topics from multiple rosbags")
    project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    data_dir = os.path.join(project_dir, 'data')
    parser.add_argument('--input-dir', "-i", type=str, default=data_dir, help="Input data directory (default: data/)")
    parser.add_argument('--topics', type=str, nargs='+', default=[], help="Topic names for visualization")
    parser.add_argument('--window', '-w', type=float, default=2.0,
                        help="Length of the plotted time window in seconds (default: 2.0)")
    args = parser.parse_args()
    if not args.topics:
        parser.error("at least one topic must be given with --topics")
    if args.window <= 0:
        parser.error("--window must be positive")
    # Drop duplicate topic names while keeping the user's order.
    topics = list(dict.fromkeys(args.topics))

    # NOTE(review): presumably returns the paths of the '.bag' files in
    # input_dir; confirm against utils.get_file_in_dir.
    bag_files = get_file_in_dir(args.input_dir, '.bag')
    print(bag_files)

    # Store the timestamps per topic.
    timestamps = collect_timestamps(bag_files, topics)
    xlim = time_window(timestamps, args.window)
    if xlim is None:
        # Without a single stamp there is no window to show (the old code
        # crashed here with plt.xlim(inf, inf)).
        raise SystemExit(f"No messages found for topics {topics} in {args.input_dir}")
    plot_timestamps(timestamps, xlim)
